import os
import gc
import re
import math
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm

# —— Table proportion normalisation (largest-remainder method) ——
def fix_table_proportions(text: str) -> str:
    """Rescale the land-use table in *text* so its percentages sum to 100.

    The table is located by its header cell ``| 用地类型``; it consists of
    the header line plus the following lines that start with ``|`` (blank
    lines in between are tolerated).  Every data row shaped like
    ``| <type> | <number>% |`` is rescaled proportionally and rounded to an
    integer with the largest-remainder method, so the rounded values add up
    to exactly 100.  A ``计算：a%+b%+... = 100%`` line is written right after
    the table, replacing an existing ``计算：`` line that directly follows it.

    Text before the table and text after the table / calc line is preserved.
    Rows whose percentage cell cannot be parsed are dropped from the
    rewritten table.

    Args:
        text: Generated scheme text that may contain the land-use table.

    Returns:
        The text with the normalised table, or *text* unchanged when no
        table is found, the table has fewer than three lines, or the parsed
        percentages do not sum to a positive number.
    """
    header_match = re.search(r"\|\s*用地类型", text)
    if not header_match:
        return text
    start = header_match.start()

    # Collect the table lines and remember where the last one ends, so that
    # whatever follows the table can be re-attached afterwards.
    table_lines = []
    end = pos = start
    for raw_line in text[start:].splitlines(keepends=True):
        pos += len(raw_line)
        stripped = raw_line.strip()
        if stripped.startswith("|"):
            table_lines.append(stripped)
            end = pos
        elif stripped:
            # First non-blank, non-table line: the table is over.
            break

    if len(table_lines) < 3:
        print("⚠️ 表格行数不足，跳过比例修正。")
        return text

    # The number pattern only accepts text that float() can parse.
    row_re = re.compile(r"\|\s*(.+?)\s*\|\s*(\d+(?:\.\d*)?|\.\d+)%\s*\|")

    header, *body = table_lines
    head_lines = [header]
    # The second line is kept verbatim as the separator unless it is already
    # a data row (some generations omit the Markdown separator line).
    if not row_re.match(body[0]):
        head_lines.append(body[0])
        body = body[1:]

    types, vals = [], []
    for row in body:
        m = row_re.match(row)
        if m:
            types.append(m.group(1))
            vals.append(float(m.group(2)))

    total = sum(vals)
    if total <= 0:
        return text

    # Largest-remainder rounding: floor every share, then hand the missing
    # points to the rows with the biggest fractional parts.
    raw_percents = [v * 100 / total for v in vals]
    percents = [math.floor(r) for r in raw_percents]
    shortfall = max(100 - sum(percents), 0)
    by_remainder = sorted(
        range(len(percents)),
        key=lambda i: raw_percents[i] - percents[i],
        reverse=True,
    )
    for i in by_remainder[:shortfall]:
        percents[i] += 1

    new_rows = [f"| {t:<22} | {p:>3d}% |" for t, p in zip(types, percents)]
    new_table = "\n".join(head_lines + new_rows)
    calc_line = "计算：" + "+".join(f"{p}%" for p in percents) + " = 100%"

    # Drop a stale calc line that directly follows the table; keep the rest.
    tail = re.sub(r"\A\s*计算：[^\n]*", "", text[end:]).lstrip()
    return text[:start] + new_table + "\n\n" + calc_line + "\n\n" + tail

# —— Structure validation ——
def validate_scheme_structure(scheme: str) -> bool:
    """Return True when *scheme* has all four section titles and the table header.

    The four titles are checked by plain substring containment; the table
    header must look like ``| 用地类型 | 比例 |`` (whitespace around the cell
    text is optional).
    """
    for title in ("一、规划思路", "二、空间布局优化思路", "三、主要更新举措", "四、土地利用规划"):
        if title not in scheme:
            return False
    return re.search(r"\|\s*用地类型\s*\|\s*比例\s*\|", scheme) is not None

# Select the compute device: GPU when CUDA is available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer and the causal LM from the local checkpoint directory.
model_dir = '/share/home/zhangshanqi/QSM/cpu/模型保存/基于专家模型的公众奖励模型ppo训练后的专家公众模型/002/融合ppo_outputzhengshibanben02'
print("🚀 开始加载模型和分词器...", flush=True)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True).to(device)
model.eval()  # switch to evaluation mode; the script only runs inference
print("✅ 模型加载完成！\n", flush=True)

# Load the evaluation prompts; the sheet must provide a "prompt" column.
data_path = '/share/home/zhangshanqi/QSM/cpu/trl/data/最终评估数据1.xlsx'
df = pd.read_excel(data_path, engine="openpyxl")
if "prompt" not in df.columns:
    raise ValueError("❌ Excel 文件中必须包含名为 'prompt' 的一列")

output_path = '/share/home/zhangshanqi/QSM/cpu/trl/data/02融合模型生成评估数据的方案.xlsx'
# NOTE(review): the message announces two schemes per prompt, while
# strategy_prompts below currently holds a single strategy.
print(f"📄 读取到 {len(df)} 条 Prompt，开始生成两个差异化方案...\n", flush=True)

# Generation settings.
# split_token is appended to every prompt to mark where the scheme body starts.
split_token = "### 以下是改造方案正文："
generate_kwargs = {
    "max_new_tokens": 1000,
    "do_sample": True,
    "top_k": 50,
    "top_p": 0.95,
    "temperature": 1.0,
    "num_return_sequences": 1,
    "eos_token_id": tokenizer.eos_token_id,
    # Some causal-LM tokenizers define no pad token; fall back to EOS so
    # generate() is never handed pad_token_id=None.
    "pad_token_id": (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    ),
}

# One instruction per scheme variant; each entry yields one output column.
strategy_prompts = [
    "请生成一个更新方案，逻辑完整，突出重点，符合城市更新政策与技术规范。",
]

# Main generation loop: one result row per prompt, one column per strategy.
# Sampling is retried until the output passes validate_scheme_structure, but
# only up to max_attempts times so that a prompt the model cannot satisfy
# does not stall the whole run.
max_attempts = 20
results = []
for idx, prompt in enumerate(tqdm(df['prompt'].astype(str), desc="生成进度"), start=1):
    prompt = prompt.strip()
    tqdm.write(f"\n🟢 Prompt #{idx}: {prompt}\n")
    entry = {"prompt": prompt}

    for i, strat in enumerate(strategy_prompts):
        # The prompt is identical across retries, so build and tokenize it once.
        prompt_text = (
            f"{prompt}\n\n"
            f"{strat}\n\n"
            "请严格按以下结构完整输出（不要放示例，不要出现省略号 ...）：\n\n"
            "一、规划思路（总体定位）\n"
            "二、空间布局优化思路（不得使用第一人称；条目式列出）\n"
            "三、主要更新举措（不得使用第一人称；条目式列出）\n"
            "四、土地利用规划（仅输出真实表格，不得放示例或占位符。表格格式如下：只保留第一行表头，其余行请填写真实内容）\n"
            "| 用地类型 | 比例 |\n"
            "表格要求：至少5行；比例写成百分数并以%结尾；所有比例之和必须为100%；不得出现“...”“—”“空白”。\n"
            f"{split_token}\n\n"
        )
        # NOTE(review): prompts longer than 2048 tokens are truncated, which
        # may cut off part of the format instructions — confirm acceptable.
        inputs = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=2048)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]

        scheme = ""
        for attempt in range(1, max_attempts + 1):
            with torch.no_grad():
                output_ids = model.generate(**inputs, **generate_kwargs)

            # Decode only the newly generated tokens.  The prompt itself holds
            # all four section titles and the table header, so letting it leak
            # into `scheme` would make the structure check pass trivially.
            text_out = tokenizer.decode(
                output_ids[0][prompt_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

            # If the model repeats the marker, keep only what follows it.
            if split_token in text_out:
                scheme = text_out.split(split_token, 1)[1].strip()
            else:
                scheme = text_out.strip()

            # Strip Markdown code fences: a leading fence line, a trailing
            # fence line, then any remaining fenced span.
            scheme = re.sub(r"^```[^\n]*\n", "", scheme)
            scheme = re.sub(r"\n```\s*$", "", scheme)
            scheme = re.sub(r"```.*?```", "", scheme, flags=re.S)

            if validate_scheme_structure(scheme):
                scheme = fix_table_proportions(scheme)
                tqdm.write(f"✅ 第 {attempt} 次尝试成功生成结构完整方案。\n")
                break
            tqdm.write(f"❌ 第 {attempt} 次生成不合格，重试中...")
        else:
            # Retry budget exhausted: keep the last attempt so the run goes on.
            tqdm.write(f"⚠️ 连续 {max_attempts} 次生成均不合格，保留最后一次输出。")

        label = f"方案{i+1}"
        tqdm.write(f"—— {label} ——")
        tqdm.write(scheme + "\n")
        entry[label] = scheme

    results.append(entry)

    # Checkpoint every 10 prompts so a crash loses at most 9 results.
    if idx % 10 == 0:
        pd.DataFrame(results).to_excel(output_path, index=False, engine="openpyxl")
        tqdm.write(f"✅ 已保存前 {idx} 条结果至：{output_path}")

    if torch.cuda.is_available():
        # Collect Python garbage first so the tensors it frees can be
        # returned by empty_cache().
        gc.collect()
        torch.cuda.empty_cache()

# Final save of all results.
pd.DataFrame(results).to_excel(output_path, index=False, engine="openpyxl")
tqdm.write(f"\n🎉 所有方案已生成完毕，结果保存至：{output_path}")
